"""
Filename: calib_lumpsum.jl
Paper: Unemployment and the Distribution of Liquidity
Authors: Zach Bethune and Guillaume Rocheteau
Contact: bethune@rice.edu
Last Modified: 8/22/2023
Purpose: functions and procedure to calibrate model with wasteful money creation
"""
###############################################################################
# Functions
###############################################################################
#Market clearing used in calibration procedure
function mkt_clear_ss_calibration!(py; mm::model_outcomes,parms::parmsmodel,md::moments_data,W_iter_tol=Wa_tol,Zmd::Real,Ag::Real)
    # Solve the calibration steady state at a given price of early consumption `py`.
    #
    # Rf and Rm are held at their data targets (md.Rf_moment, md.Rm_moment). A "fiscal"
    # fixed-point loop iterates on bond supply Ag and real money demand Zmd, which together
    # set the lump-sum transfer `tau`. Each pass solves the firm problem (recalibrating
    # parms.K from md.JOB_FILLING), the household problem (Wa_solve!), and the stationary
    # distribution. Outcomes are written into `mm`; parms.K is overwritten.
    #
    # Keywords: `Zmd`, `Ag` are starting guesses for money demand and bond supply;
    # `W_iter_tol` is forwarded to Wa_solve! as its tolerance.
    # Returns the excess supply of early consumption at `py` (zero in equilibrium), so the
    # function can be handed directly to a root finder in `py`.

    mm.Ag = copy(Ag)              # updated at the end of every fiscal iteration
    mm.Rf = copy(md.Rf_moment)    # target illiquid real return
    Rf = copy(md.Rf_moment)
    mm.Rm = copy(md.Rm_moment)    # target real return on MZM

    #initialize fiscal loop
    Ag0 = copy(mm.Ag)
    Zmd0 = copy(Zmd)
    mm.Wa = copy(mm.Wa_p)         # initial guess of value function
    err = 1.0                     # named `err` to avoid shadowing Base.error
    count_iter = 0

    while err > fiscal_tol && count_iter < fiscal_iter_max  # loops over fiscal transfers
        count_iter += 1

        ##########################
        #taxes and transfers
        ##########################
        # lump-sum money creation: the same transfer in every cell of the (Na, 2, Nz) grid
        mm.tau = zeros((Na,2,Nz))
        taulumpsum = ((1.0/Rf)-1.0)*Ag0 + ((1.0/mm.Rm)-1.0)*Zmd0
        mm.tau .= taulumpsum

        ##########################
        #firm problem
        ##########################
        mm.Ys = κ_prime_inv(py;parms)
        mm.frev = mm.zgrid.*(1.0 + py*mm.Ys - κ(mm.Ys;parms))
        # row 1: labor share of revenue; row 2: REPLACE_RATE times row 1
        mm.wages_bar = zeros((2,Nz))
        mm.wages_bar[1,:] = md.LABOR_SHARE.*mm.frev
        mm.wages_bar[2,:] = md.REPLACE_RATE.*mm.wages_bar[1,:]
        # income in every state = wages_bar plus the lump-sum transfer
        mm.wages = zeros((Na,2,Nz))
        for j=1:Nz
            for i=1:2
                mm.wages[:,i,j] .= mm.wages_bar[i,j].+mm.tau[:,i,j]
            end
        end
        profits = mm.frev.-mm.wages_bar[1,:]
        # job value: profits discounted at Rf with continuation probability 1-DELTA
        mm.Js = (profits)./(1.0-((1.0-parms.DELTA)/Rf))

        # calibrate entry costs so that zgrid[1]*K*Rf/Js[1] (the q_inv argument below)
        # equals md.JOB_FILLING
        parms.K = (mm.Js[1]*md.JOB_FILLING)/(Rf*mm.zgrid[1])

        #steady state tightness and employment rate (theta falls back to 0.0001 otherwise)
        entry_ratio = (mm.zgrid.*parms.K*Rf)./mm.Js
        if maximum(entry_ratio) < 1
            mm.theta = q_inv(entry_ratio[1];parms)
        else
            mm.theta = 0.0001
        end
        mm.emp = λ(mm.theta;parms)/(parms.DELTA+λ(mm.theta;parms))

        ##########################
        #household problem
        ##########################
        # warm-start from the value function of the previous fiscal iteration
        Wa_temp = deepcopy(mm.Wa)
        Wa_solve!(mm;tol=W_iter_tol,py,Rf,Wa_p=Wa_temp,mm.agrid,parms)

        ##########################
        #solve for transition matrix T and distributions
        ##########################
        transition_g!(mm;py,parms) #solve for transition matrix T
        # power iteration on T from a random start vector
        _, ghat = powm!(mm.T,rand(Na*2*Nz),tol=1e-10,maxiter=1000000)
        mm.g = reshape(ghat,(Na,2,Nz))
        # rescale so the mass in productivity slice j equals gz[j]
        for j=1:Nz
            mm.g[:,:,j] = mm.g[:,:,j]*(mm.gz[j]/sum(mm.g[:,:,j]))
        end
        mm.ga = sum(mm.g,dims=3)[:,:,1]

        ##########################
        #compute market clearing condition wrt py
        ##########################
        #early consumption goods
        mm.Yd = 0.0
        for l=1:Nz
            # the ystar interpolant depends only on (kk, l): build it once per l rather than
            # once per (k, kk, l); the accumulation order into mm.Yd is unchanged
            ystar_ints = [LinearInterpolation(mm.agrid, mm.ystar[:,kk,l], extrapolation_bc=Line()) for kk=1:2]
            for k=1:2
                for kk=1:2
                    ystar_int = ystar_ints[kk]
                    for j=1:Na
                        mm.Yd += mm.PI[k,kk]*parms.ALPHA*parms.ALPHAmf*ymf(ystar_int,mm.a_p[j,k,l],mm.am_p[j,k,l],py)[1]*mm.g[j,k,l]
                        mm.Yd += mm.PI[k,kk]*parms.ALPHA*(1.0-parms.ALPHAmf)*ym(ystar_int,mm.a_p[j,k,l],mm.am_p[j,k,l],py)[1]*mm.g[j,k,l]
                    end
                end
            end
        end

        ##########################
        #compute aggregate illiquid wealth demand and update bond supply to clear market
        ##########################
        mm.Jd = sum(mm.g.*(mm.a_p.-mm.am_p))
        Ag1 = mm.Jd - mm.emp*sum(mm.gz.*mm.Js)

        ##########################
        #update real money demand
        ##########################
        Zmd1 = sum(mm.g.*mm.am_p)

        ##########################
        #compute error and update
        ##########################
        err = max(abs(Ag1.-Ag0),abs(Zmd1.-Zmd0))
        Ag0 = copy(Ag1)
        Zmd0 = copy(Zmd1)
        mm.Ag = copy(Ag1)
        mm.Zmd = copy(Zmd1)
    end

    # The loop also exits on iteration cap or NaN error; make a non-converged
    # fiscal fixed point visible instead of returning it silently.
    if !(err <= fiscal_tol)
        @warn "mkt_clear_ss_calibration!: fiscal loop did not converge" py err count_iter maxlog=10
    end

    #record prices
    mm.py = copy(py)
    mm.Rf = copy(Rf)

    #excess supply of early consumption
    z = mm.emp*sum(mm.zgrid.*mm.gz)*mm.Ys - mm.Yd
    return z
end

# Objective for the calibration: given sub_parms = (BETA, a, A, ALPHA, ALPHAmf), solve the
# steady state and return the sum of squared deviations between model and data distributions
# of liquid wealth-to-income and liquid share of wealth. Points outside the admissible box
# receive a flat penalty of 1e4 and leave parinner/mminner/mmts untouched.
function SSR_calib!(sub_parms;mminner::model_outcomes,parinner::parmsmodel,md::moments_data,mmts::model_moments)
    # admissible (lower, upper) bounds, in the order of sub_parms
    bounds = ((0.94^(1/12), 0.95^(1/12)),   # BETA
              (0.2, 0.35),                   # a
              (2.0, 3.0),                    # A
              (0.05, 0.10),                  # ALPHA
              (0.02, 0.10))                  # ALPHAmf
    if any(sub_parms[i] < lo || sub_parms[i] > hi for (i, (lo, hi)) in enumerate(bounds))
        return 1e4
    end

    # write the candidate parameters into the parameter struct
    parinner.BETA, parinner.a, parinner.A, parinner.ALPHA, parinner.ALPHAmf =
        sub_parms[1], sub_parms[2], sub_parms[3], sub_parms[4], sub_parms[5]

    # initial guess of value functions: u'(a) in column 1, 2u'(a) in column 2
    uprime_on_grid = u_prime.(mminner.agrid; parms=parinner)
    mminner.Wa_p[:,1,:] .= uprime_on_grid
    mminner.Wa_p[:,2,:] .= 2.0 .* uprime_on_grid

    # steady state: find py in (0.5, 0.8) that clears the early-consumption market;
    # the solver mutates mminner/parinner (including K, Ag, Zmd) as a side effect.
    # NOTE(review): the root itself is discarded, so mminner holds the state of the last py
    # Brent evaluated, which may not be exactly the returned root — confirm this is acceptable.
    excess_supply(py) = mkt_clear_ss_calibration!(py; mm=mminner, parms=parinner, md=md,
                                                  W_iter_tol=Wa_tol, Zmd=mminner.Zmd, Ag=mminner.Ag)
    find_zero(excess_supply, (0.5,0.8), Roots.Brent(), verbose=false, atol=mkt_clear_calib_inner_tol)

    # model moments
    (mmts.G, mmts.G1, mmts.G0, mmts.gmi, mmts.Gmi, mmts.li_range,
     mmts.gls, mmts.Gls, mmts.ls_range, mmts.gwi, mmts.Gwi, mmts.wi_range,
     mmts.gini, mmts.liquid_share, mmts.liquid_to_income, mmts.gbonds_share,
     mmts.equity_share, mmts.wealth_to_income, mmts.liquid_prem) = other_mmts(;mm=mminner,parms=parinner)

    # Read a two-column CSV (x, value) and return a function of x that interpolates it
    # linearly and returns 1.0 beyond the largest x in the file (presumably an empirical
    # CDF — TODO confirm against the data files).
    function empirical_cdf(path)
        tbl = CSV.read(path, DataFrame)
        interp = LinearInterpolation(Interpolations.deduplicate_knots!(tbl[:,1]), Interpolations.deduplicate_knots!(tbl[:,2]), extrapolation_bc=Line())
        xmax = maximum(tbl[:,1])
        return x -> x > xmax ? 1.0 : interp.(x)
    end

    liquid_income_cdf = empirical_cdf("../empirical_data/liquid_income_dist.csv")
    liquid_share_cdf = empirical_cdf("../empirical_data/liquid_share_dist.csv")

    SSR = sum((liquid_income_cdf.(mmts.li_range) .- mmts.Gmi).^2)
    SSR += sum((liquid_share_cdf.(mmts.ls_range) .- mmts.Gls).^2)
    return SSR
end

###############################################################################
# Define parameter structure and exogenously chosen parameters
###############################################################################
parms_base = parmsmodel()
# Exogenously chosen parameters (not targeted by the calibration loop).
let exogenous = (
        c = 1.6,                          # curvature of matching function
        C = 1.0,                          # level of matching function
        DELTA = copy(md.JOB_SEPARATION),  # job destruction rate
        ELL = 0.0,                        # utility of leisure
        D = 1.0,                          # level of late consumption utility
        d = 1.5,                          # curvature of late consumption utility
        π_l = 0.6,                        # share of low skilled
        π_m = 0.3,                        # share of middle skilled
        meanZ = 1.0,                      # average labor productivity
    )
    for (field, value) in pairs(exogenous)
        setproperty!(parms_base, field, value)
    end
end

# returns set to their targeted data moments
mm.Rm = copy(md.Rm_moment)
mm.Rf = copy(md.Rf_moment)

###############################################################################
# Productivity distribution
###############################################################################
# population shares of the three productivity types
mm.gz = zeros(Nz)
let gz = mm.gz, p = parms_base
    gz[1] = p.π_l                 # low skilled
    gz[2] = p.π_m                 # middle skilled
    gz[3] = 1.0 - gz[1] - gz[2]   # remainder: top type
end

# productivity levels of the three types
mm.zgrid = zeros(Nz)
let z = mm.zgrid, gz = mm.gz, meanZ = parms_base.meanZ, premium = md.SKILL_PREMIUM
    # top type chosen so that gz[3]*z[3] = meanZ*TOP_INC_SHARE
    z[3] = meanZ * md.TOP_INC_SHARE / gz[3]
    # low type solves meanZ = gz[1]*z[1] + gz[2]*(premium*z[1]) + gz[3]*z[3]
    z[1] = (meanZ - gz[3] * z[3]) / (gz[2] * premium + gz[1])
    # middle type is the skill premium times the low type
    z[2] = premium * z[1]
end

###############################################################################
# Set initial guess for other parameters in calibration
###############################################################################
# Starting values. SSR_calib! overwrites BETA, a, A, ALPHA, ALPHAmf whenever it evaluates an
# admissible point, and mkt_clear_ss_calibration! recomputes K on every call.
let beta_monthly = 0.952^(1/12),
    starting_values = (
        a = 0.2,          # early-stage utility curvature parameter
        A = 2.1,          # early-stage utility level parameter
        K = 0.02,         # fixed entry cost
        b = 0.75,         # cost curvature parameter
        ALPHA = 0.11,     # probability of expenditure shock
        ALPHAmf = 0.36,   # acceptability of (partially) illiquid assets
    )
    parms_base.BETA = beta_monthly                   # monthly discount factor
    parms_base.rho = (1.0 / parms_base.BETA) - 1.0   # discount rate implied by BETA
    for (field, value) in pairs(starting_values)
        setproperty!(parms_base, field, value)
    end
end

###############################################################################
# Initialize structure to hold model moments computed in calibration routine
###############################################################################
print("\n\nCalibrating model:")
mmts = model_moments();
# Initial guess for (BETA, a, A, ALPHA, ALPHAmf).
# NOTE(review): ALPHA = 0.11 and ALPHAmf = 0.36 lie outside the admissible box enforced in
# SSR_calib! (ALPHA <= 0.10, ALPHAmf <= 0.10), so the starting point gets the flat 1e4
# penalty — confirm whether the guess or the bounds are the intended ones.
parms_guess = [0.99556 0.2 2.1 0.11 0.36];
output_calib = optimize(x->SSR_calib!(x;mminner=mm,parinner=parms_base,mmts=mmts,md=md), parms_guess, NelderMead());
# SSR_calib! returns exactly 1e4 for inadmissible points; if that is still the best value
# found, the search never evaluated an admissible point and the saved calibration is stale.
if minimum(output_calib) >= 1e4
    @warn "Calibration never reached the admissible parameter region; results below are not calibrated" minimizer(output_calib)
end
SSR_calib!(minimizer(output_calib),mminner=mm,parinner=parms_base,mmts=mmts,md=md);
save("../model_data_lumpsum/base_calib.jld2","parms_base",parms_base,"mm",mm,"md",md,"mmts",mmts);
# "\c" is not a valid Julia escape sequence (syntax error); the intent was a newline
print("\ncomplete")

#END